"""
This file implements language-model contamination detection using the Min-K% Prob approach.
https://arxiv.org/pdf/2310.16789.pdf
"""

from utils.tools import *
from tqdm import tqdm
import torch
import numpy as np
from utils.tools import fig_fpr_tpr

def inference(model, tokenizer, sentence, example):
    """Score one text with perplexity and Min-K% Prob and attach the scores.

    Args:
        model: Language model forwarded to ``calculatePerplexity``; its
            ``device`` attribute selects where the input ids are placed.
        tokenizer: Tokenizer forwarded to ``calculatePerplexity``.
        sentence: Text to score.
        example: Mutable mapping describing the sample. It is updated in
            place with a ``"pred"`` entry.

    Returns:
        The same ``example`` object. ``example["pred"]`` maps ``"ppl"`` to the
        perplexity and ``"Min_<pct>% Prob"`` to the negated mean of the
        lowest ``pct`` percent of per-token log-probabilities.
    """
    ppl, token_log_probs, _ = calculatePerplexity(
        sentence, model, tokenizer, gpu=model.device
    )
    pred = {"ppl": ppl}

    # Sort once (ascending); every ratio just takes a longer prefix of it.
    sorted_log_probs = np.sort(token_log_probs)
    n_tokens = len(sorted_log_probs)

    for ratio in [0.05, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6]:
        key = f"Min_{ratio*100}% Prob"
        if n_tokens == 0:
            # Nothing was scored (e.g. a single-token input): the score is
            # undefined, so record NaN explicitly instead of averaging an
            # empty slice.
            pred[key] = float("nan")
            continue
        # Always keep at least one token; otherwise short inputs truncate
        # to k == 0 and the mean of an empty slice is NaN.
        k_length = max(1, int(n_tokens * ratio))
        pred[key] = -np.mean(sorted_log_probs[:k_length]).item()

    example["pred"] = pred
    return example
        

def mink(data, model, tokenizer, key_name="input", output_dir="results"):
    """Run ``inference`` over every example in ``data`` and report the results.

    Args:
        data: Iterable of examples; each must be indexable by ``key_name``.
        model: Model handed to ``inference``.
        tokenizer: Tokenizer handed to ``inference``.
        key_name: Key under which each example stores the text to score.
        output_dir: Base directory; ``"/mink"`` is appended before the
            results are handed to ``fig_fpr_tpr``.

    Returns:
        List of the values returned by ``inference``, in input order.
    """
    scored_examples = [
        inference(model, tokenizer, record[key_name], record)
        for record in tqdm(data)
    ]

    # fig_fpr_tpr comes from utils.tools; presumably it writes FPR/TPR
    # figures/metrics under the given directory -- confirm there.
    fig_fpr_tpr(scored_examples, f"{output_dir}/mink")
    return scored_examples


def calculatePerplexity(sentence, model, tokenizer, gpu):
    """Compute the perplexity and per-token log-probabilities of ``sentence``.

    Args:
        sentence: Text to score.
        model: Callable invoked as ``model(input_ids, labels=input_ids)``.
            The first two items of its output are used as the loss and the
            logits; the logits are indexed as ``[batch, position, vocab]``.
        tokenizer: Object whose ``encode(sentence)`` returns the token ids.
        gpu: Device the input ids are moved to before the forward pass.

    Returns:
        A tuple ``(perplexity, all_prob, loss)`` where ``perplexity`` is
        ``exp(loss)``, ``all_prob`` is a list with the log-probability the
        model assigned to each token after the first (in order), and
        ``loss`` is the loss as a Python float.
    """
    input_ids = torch.tensor(tokenizer.encode(sentence)).unsqueeze(0).to(gpu)

    with torch.no_grad():
        outputs = model(input_ids, labels=input_ids)
        loss, logits = outputs[:2]
        # These are log-probabilities (log_softmax), not probabilities.
        log_probs = torch.nn.functional.log_softmax(logits, dim=-1)

        # Logits at position i score the token at position i + 1, so the
        # targets are the input shifted left by one; the first token has no
        # prediction and is skipped.
        targets = input_ids[0, 1:].to(log_probs.device)
        n_targets = targets.shape[0]

        # One gather (and one device-to-host copy) instead of a Python loop
        # doing a separate .item() transfer for every token.
        token_log_probs = (
            log_probs[0, :n_targets].gather(-1, targets.unsqueeze(-1)).squeeze(-1)
        )

    return torch.exp(loss).item(), token_log_probs.tolist(), loss.item()


# This module only defines functions; running it directly as a script is
# intentionally a no-op.
if __name__ == "__main__":
    pass

